{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notebook written by [Zhedong Zheng](https://github.com/zhedongzheng)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "PARAMS = {\n",
    "    'max_len': 15,\n",
    "    'embed_dims': 128,\n",
    "    'rnn_size': 128,\n",
    "    'hidden_dim': 128,\n",
    "    'num_heads': 4,\n",
    "    'latent_size': 16,\n",
    "    'n_hidden_layer': 1,\n",
    "    'num_sampled': 1000,\n",
    "    'clip_norm': 5.0,\n",
    "    'anneal_max': 1.0,\n",
    "    'anneal_bias': 6000,\n",
    "    'batch_size': 128,\n",
    "    'n_epochs': 10,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_vocab(index_from=4):\n",
    "    PARAMS['word2idx'] = tf.keras.datasets.imdb.get_word_index()\n",
    "    PARAMS['word2idx'] = {k: (v + index_from) for k, v in PARAMS['word2idx'].items()}\n",
    "    PARAMS['word2idx']['<pad>'] = 0\n",
    "    PARAMS['word2idx']['<start>'] = 1\n",
    "    PARAMS['word2idx']['<unk>'] = 2\n",
    "    PARAMS['word2idx']['<end>'] = 3\n",
    "    \n",
    "    PARAMS['idx2word'] = {i: w for w, i in PARAMS['word2idx'].items()}\n",
    "    \n",
    "    PARAMS['vocab_size'] = len(PARAMS['word2idx'])\n",
    "\n",
    "    \n",
    "def load_data(index_from=4):\n",
    "    (X_train, _), (X_test, _) = tf.contrib.keras.datasets.imdb.load_data(\n",
    "        num_words=None, index_from=index_from)\n",
    "    return (X_train, X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "word2idx = build_vocab()\n",
    "X = np.concatenate(load_data())\n",
    "\n",
    "X = np.concatenate((\n",
    "    tf.keras.preprocessing.sequence.pad_sequences(\n",
    "        X, PARAMS['max_len'], truncating='post', padding='post'),\n",
    "    tf.keras.preprocessing.sequence.pad_sequences(\n",
    "        X, PARAMS['max_len'], truncating='pre', padding='post')))\n",
    "\n",
    "enc_inp = X[:, 1:]\n",
    "dec_inp = X\n",
    "dec_out = np.concatenate([X[:, 1:], np.full([X.shape[0], 1], PARAMS['word2idx']['<end>'])], 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def reparam_trick(z_mean, z_logvar):\n",
    "    gaussian = tf.truncated_normal(tf.shape(z_logvar))\n",
    "    z = z_mean + tf.exp(0.5 * z_logvar) * gaussian\n",
    "    return z\n",
    "\n",
    "\n",
    "def kl_w_fn(global_step):\n",
    "    return PARAMS['anneal_max'] * tf.sigmoid((10 / PARAMS['anneal_bias']) * (\n",
    "        tf.to_float(global_step) - tf.constant(PARAMS['anneal_bias'] / 2)))\n",
    "\n",
    "\n",
    "def kl_loss_fn(mean, gamma):\n",
    "    return 0.5 * tf.reduce_sum(\n",
    "        tf.exp(gamma) + tf.square(mean) - 1 - gamma) / tf.to_float(tf.shape(mean)[0])\n",
    "\n",
    "\n",
    "def clip_grads(loss):\n",
    "    variables = tf.trainable_variables()\n",
    "    grads = tf.gradients(loss, variables)\n",
    "    clipped_grads, _ = tf.clip_by_global_norm(grads, PARAMS['clip_norm'])\n",
    "    return zip(clipped_grads, variables)\n",
    "\n",
    "\n",
    "def rnn_cell():\n",
    "    return tf.nn.rnn_cell.GRUCell(PARAMS['rnn_size'],\n",
    "                                  kernel_initializer=tf.orthogonal_initializer())\n",
    "\n",
    "\n",
    "def embed_seq(x, vocab_sz, embed_dim, name, zero_pad=False, scale=False):\n",
    "    embedding = tf.get_variable(name, [vocab_sz, embed_dim])\n",
    "    if zero_pad:\n",
    "        embedding = tf.concat([tf.zeros([1, embed_dim]), embedding[1:, :]], 0)\n",
    "    x = tf.nn.embedding_lookup(embedding, x)\n",
    "    if scale:\n",
    "        x = x * np.sqrt(embed_dim)\n",
    "    return x\n",
    "\n",
    "\n",
    "def position_embedding(inputs):\n",
    "    T = inputs.get_shape().as_list()[1]\n",
    "    x = tf.range(T)                            # (T)\n",
    "    x = tf.expand_dims(x, 0)                   # (1, T)\n",
    "    x = tf.tile(x, [tf.shape(inputs)[0], 1])   # (N, T)\n",
    "    return embed_seq(x, T, PARAMS['hidden_dim'], 'position_embedding')\n",
    "\n",
    "\n",
    "def layer_norm(inputs, epsilon=1e-8):\n",
    "    mean, variance = tf.nn.moments(inputs, [-1], keep_dims=True)\n",
    "    normalized = (inputs - mean) / (tf.sqrt(variance + epsilon))\n",
    "    \n",
    "    params_shape = inputs.get_shape()[-1:]\n",
    "    gamma = tf.get_variable('gamma', params_shape, tf.float32, tf.ones_initializer())\n",
    "    beta = tf.get_variable('beta', params_shape, tf.float32, tf.zeros_initializer())\n",
    "    \n",
    "    return gamma * normalized + beta\n",
    "\n",
    "\n",
    "def self_attention(inputs, is_training):\n",
    "    num_units = PARAMS['hidden_dim']\n",
    "    num_heads = PARAMS['num_heads']\n",
    "    T_q = T_k = inputs.get_shape()[1].value\n",
    "\n",
    "    Q_K_V = tf.layers.dense(inputs, 3*num_units)\n",
    "    Q, K, V = tf.split(Q_K_V, 3, -1)\n",
    "    Q_ = tf.concat(tf.split(Q, num_heads, axis=2), 0)                         \n",
    "    K_ = tf.concat(tf.split(K, num_heads, axis=2), 0)                        \n",
    "    V_ = tf.concat(tf.split(V, num_heads, axis=2), 0)                         \n",
    "\n",
    "    align = tf.matmul(Q_, K_, transpose_b=True)                               \n",
    "    align = align / np.sqrt(K_.get_shape().as_list()[-1])\n",
    "\n",
    "    paddings = tf.fill(tf.shape(align), float('-inf'))         \n",
    "    lower_tri = tf.ones([T_q, T_k])                                                \n",
    "    lower_tri = tf.linalg.LinearOperatorLowerTriangular(lower_tri).to_dense()      \n",
    "    masks = tf.tile(tf.expand_dims(lower_tri,0), [tf.shape(align)[0],1,1])       \n",
    "    align = tf.where(tf.equal(masks, 0), paddings, align)               \n",
    "\n",
    "    align = tf.nn.softmax(align)                                                  \n",
    "    align = tf.layers.dropout(align, 0.1, training=is_training)           \n",
    "    x = tf.matmul(align, V_)                                                 \n",
    "    x = tf.concat(tf.split(x, num_heads, axis=0), 2)              \n",
    "    x += inputs                                                                \n",
    "    x = layer_norm(x)                                                 \n",
    "    return x\n",
    "\n",
    "\n",
    "def ffn(inputs, activation=tf.nn.relu):\n",
    "    x = tf.layers.conv1d(inputs, 4*PARAMS['hidden_dim'], 1, activation=activation)\n",
    "    x = tf.layers.conv1d(x, PARAMS['hidden_dim'], 1)\n",
    "    x += inputs\n",
    "    x = layer_norm(x)\n",
    "    return x\n",
    "\n",
    "\n",
    "def attn_forward(x, output_proj, is_training):\n",
    "    x += position_embedding(x)\n",
    "    x = tf.layers.dropout(x, 0.1, training=is_training)\n",
    "    for i in range(PARAMS['n_hidden_layer']):\n",
    "        with tf.variable_scope('attn_%d'%i):\n",
    "            x = self_attention(x, is_training)\n",
    "        with tf.variable_scope('ffn_%d'%i):\n",
    "            x = ffn(x)\n",
    "    return output_proj(x), x\n",
    "\n",
    "\n",
    "def autoregressive(embedding, z, input_proj, output_proj):\n",
    "    batch_sz = tf.shape(z)[0]\n",
    "    \n",
    "    def cond(i, x, temp):\n",
    "        return i < PARAMS['max_len']\n",
    "\n",
    "    def body(i, x, temp):\n",
    "        sos = tf.fill([batch_sz, 1], PARAMS['word2idx']['<start>'])\n",
    "        x = tf.concat([sos, x[:, :-1]], 1)\n",
    "        \n",
    "        x = tf.nn.embedding_lookup(embedding, x)\n",
    "        x = input_proj(tf.concat([x, z], -1))\n",
    "        logits = attn_forward(x, output_proj, is_training=False)[0]\n",
    "        ids = tf.argmax(logits, -1, output_type=tf.int32)[:, i]\n",
    "        ids = tf.expand_dims(ids, -1)\n",
    "\n",
    "        temp = tf.concat([temp[:, 1:], ids], -1)\n",
    "\n",
    "        x = tf.concat([temp[:, -(i+1):], temp[:, :-(i+1)]], -1)\n",
    "        x = tf.reshape(x, [batch_sz, PARAMS['max_len']])\n",
    "        i += 1\n",
    "        return i, x, temp\n",
    "    \n",
    "    x = tf.zeros([batch_sz, PARAMS['max_len']], tf.int32)\n",
    "    _, res, _ = tf.while_loop(cond, body, [tf.constant(0), x, x])\n",
    "    \n",
    "    return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def forward(inputs, labels, mode):\n",
    "    is_training = (mode == tf.estimator.ModeKeys.TRAIN)\n",
    "    enc_seq_len = tf.count_nonzero(inputs, 1, dtype=tf.int32)\n",
    "    batch_sz = tf.shape(inputs)[0]\n",
    "    \n",
    "    with tf.variable_scope('Encoder'):\n",
    "        embedding = tf.get_variable('lookup_table', [len(PARAMS['word2idx']), PARAMS['embed_dims']])\n",
    "        x = tf.nn.embedding_lookup(embedding, inputs)\n",
    "        \n",
    "        _, enc_state = tf.nn.dynamic_rnn(rnn_cell(), x, enc_seq_len, dtype=tf.float32)\n",
    "        \n",
    "        z_mean = tf.layers.dense(enc_state, PARAMS['latent_size'])\n",
    "        z_logvar = tf.layers.dense(enc_state, PARAMS['latent_size'])\n",
    "        \n",
    "    z = reparam_trick(z_mean, z_logvar)\n",
    "        \n",
    "    with tf.variable_scope('Decoder'):\n",
    "        input_proj = tf.layers.Dense(PARAMS['hidden_dim'], tf.nn.relu)\n",
    "        output_proj = tf.layers.Dense(len(PARAMS['word2idx']), _scope='decoder/output_proj')\n",
    "        z = tf.tile(tf.expand_dims(z, 1), [1, PARAMS['max_len'], 1])\n",
    "        \n",
    "        if is_training:\n",
    "            dec_inp = tf.nn.embedding_lookup(embedding, labels['dec_inp'])\n",
    "            x = input_proj(tf.concat([dec_inp, z], -1))\n",
    "            logits, x = attn_forward(x, output_proj, is_training)\n",
    "            return x, logits, (z_mean, z_logvar)\n",
    "        else:\n",
    "            return autoregressive(embedding, z, input_proj, output_proj)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def model_fn(features, labels, mode):\n",
    "    logits_or_ids = forward(features, labels, mode)        \n",
    "    \n",
    "    if mode == tf.estimator.ModeKeys.PREDICT:\n",
    "        return tf.estimator.EstimatorSpec(mode, predictions=logits_or_ids)\n",
    "        \n",
    "    if mode == tf.estimator.ModeKeys.TRAIN:\n",
    "        attn_output, logits, (z_mean, z_logvar) = logits_or_ids\n",
    "        \n",
    "        global_step = tf.train.get_global_step()\n",
    "        \n",
    "        with tf.variable_scope('Decoder/decoder/output_proj', reuse=True):\n",
    "            _weights = tf.transpose(tf.get_variable('kernel'))\n",
    "            _biases = tf.get_variable('bias')\n",
    "        \n",
    "        mask = tf.reshape(tf.to_float(tf.sign(labels['dec_out'])), [-1])\n",
    "        \n",
    "        nll_loss = tf.reduce_sum(mask * tf.nn.sampled_softmax_loss(\n",
    "            weights = _weights,\n",
    "            biases = _biases,\n",
    "            labels = tf.reshape(labels['dec_out'], [-1, 1]),\n",
    "            inputs = tf.reshape(attn_output, [-1, PARAMS['hidden_dim']]),\n",
    "            num_sampled = PARAMS['num_sampled'],\n",
    "            num_classes = PARAMS['vocab_size'],\n",
    "        )) / tf.to_float(tf.shape(features)[0])\n",
    "        \n",
    "        kl_w = kl_w_fn(global_step)\n",
    "        \n",
    "        kl_loss = kl_loss_fn(z_mean, z_logvar)\n",
    "        \n",
    "        loss_op = nll_loss + kl_w * kl_loss\n",
    "        \n",
    "        train_op = tf.train.AdamOptimizer().apply_gradients(\n",
    "            clip_grads(loss_op),\n",
    "            global_step = global_step)\n",
    "        \n",
    "        lth = tf.train.LoggingTensorHook(\n",
    "            {'nll_loss': nll_loss, 'kl_w': kl_w, 'kl_loss': kl_loss}, every_n_iter=100)\n",
    "        \n",
    "        return tf.estimator.EstimatorSpec(\n",
    "            mode=mode, loss=loss_op, train_op=train_op, training_hooks=[lth])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using default config.\n",
      "WARNING:tensorflow:Using temporary folder as model directory: /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp5q9zil3i\n",
      "INFO:tensorflow:Using config: {'_model_dir': '/var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp5q9zil3i', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x11f0a56d8>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "WARNING:tensorflow:From /usr/local/lib/python3.6/site-packages/tensorflow/python/ops/nn_impl.py:1344: softmax_cross_entropy_with_logits (from tensorflow.python.ops.nn_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "\n",
      "Future major versions of TensorFlow will allow gradients to flow\n",
      "into the labels input on backprop by default.\n",
      "\n",
      "See @{tf.nn.softmax_cross_entropy_with_logits_v2}.\n",
      "\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Create CheckpointSaverHook.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "INFO:tensorflow:Saving checkpoints for 1 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp5q9zil3i/model.ckpt.\n",
      "INFO:tensorflow:loss = 156.89622, step = 1\n",
      "INFO:tensorflow:nll_loss = 156.89622, kl_w = 0.006692851, kl_loss = 4.798962e-05\n",
      "INFO:tensorflow:global_step/sec: 2.53474\n",
      "INFO:tensorflow:loss = 85.32051, step = 101 (39.453 sec)\n",
      "INFO:tensorflow:nll_loss = 85.30536, kl_w = 0.007897083, kl_loss = 1.919153 (39.453 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.62346\n",
      "INFO:tensorflow:loss = 75.754295, step = 201 (38.118 sec)\n",
      "INFO:tensorflow:nll_loss = 75.6801, kl_w = 0.009315956, kl_loss = 7.964015 (38.118 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.66891\n",
      "INFO:tensorflow:loss = 71.74492, step = 301 (37.469 sec)\n",
      "INFO:tensorflow:nll_loss = 71.540115, kl_w = 0.010986943, kl_loss = 18.640404 (37.469 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.66942\n",
      "INFO:tensorflow:loss = 67.676384, step = 401 (37.460 sec)\n",
      "INFO:tensorflow:nll_loss = 67.37237, kl_w = 0.012953726, kl_loss = 23.469288 (37.460 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.77597\n",
      "INFO:tensorflow:loss = 71.23791, step = 501 (36.023 sec)\n",
      "INFO:tensorflow:nll_loss = 70.8468, kl_w = 0.015267149, kl_loss = 25.617533 (36.023 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.59824\n",
      "INFO:tensorflow:loss = 63.92441, step = 601 (38.488 sec)\n",
      "INFO:tensorflow:nll_loss = 63.430676, kl_w = 0.01798621, kl_loss = 27.450748 (38.488 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.65233\n",
      "INFO:tensorflow:loss = 65.2393, step = 701 (37.702 sec)\n",
      "INFO:tensorflow:nll_loss = 64.64574, kl_w = 0.021179108, kl_loss = 28.026161 (37.702 sec)\n",
      "INFO:tensorflow:Saving checkpoints for 782 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp5q9zil3i/model.ckpt.\n",
      "INFO:tensorflow:Loss for final step: 63.254906.\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp5q9zil3i/model.ckpt-782\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "\n",
      "Original: i love this film and i think it is one of the best films\n",
      "Reconstr: i would have to see this movie it is a good movie and bad <end>\n",
      "\n",
      "Original: this movie is a waste of time and there is no point to watch it\n",
      "Reconstr: this movie is not a good and and that is not worth it out <end>\n",
      "\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Create CheckpointSaverHook.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp5q9zil3i/model.ckpt-782\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "INFO:tensorflow:Saving checkpoints for 783 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp5q9zil3i/model.ckpt.\n",
      "INFO:tensorflow:loss = 61.99346, step = 783\n",
      "INFO:tensorflow:nll_loss = 61.28309, kl_w = 0.024205629, kl_loss = 29.347403\n",
      "INFO:tensorflow:global_step/sec: 2.5101\n",
      "INFO:tensorflow:loss = 61.349503, step = 883 (39.840 sec)\n",
      "INFO:tensorflow:nll_loss = 60.52833, kl_w = 0.028470589, kl_loss = 28.842724 (39.840 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.59544\n",
      "INFO:tensorflow:loss = 59.302696, step = 983 (38.529 sec)\n",
      "INFO:tensorflow:nll_loss = 58.32515, kl_w = 0.033461247, kl_loss = 29.214266 (38.529 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.61509\n",
      "INFO:tensorflow:loss = 61.513115, step = 1083 (38.239 sec)\n",
      "INFO:tensorflow:nll_loss = 60.43241, kl_w = 0.039291352, kl_loss = 27.504892 (38.239 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.70904\n",
      "INFO:tensorflow:loss = 59.887875, step = 1183 (36.913 sec)\n",
      "INFO:tensorflow:nll_loss = 58.6277, kl_w = 0.04608883, kl_loss = 27.342295 (36.914 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.58449\n",
      "INFO:tensorflow:loss = 56.814224, step = 1283 (38.693 sec)\n",
      "INFO:tensorflow:nll_loss = 55.382885, kl_w = 0.053996176, kl_loss = 26.508183 (38.693 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.56481\n",
      "INFO:tensorflow:loss = 56.07279, step = 1383 (38.989 sec)\n",
      "INFO:tensorflow:nll_loss = 54.398735, kl_w = 0.06317033, kl_loss = 26.500652 (38.989 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.62912\n",
      "INFO:tensorflow:loss = 52.393402, step = 1483 (38.036 sec)\n",
      "INFO:tensorflow:nll_loss = 50.44125, kl_w = 0.07378165, kl_loss = 26.458508 (38.037 sec)\n",
      "INFO:tensorflow:Saving checkpoints for 1564 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp5q9zil3i/model.ckpt.\n",
      "INFO:tensorflow:Loss for final step: 58.71571.\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp5q9zil3i/model.ckpt-1564\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "\n",
      "Original: i love this film and i think it is one of the best films\n",
      "Reconstr: i am surprised that it was one of all time and the movie was <end>\n",
      "\n",
      "Original: this movie is a waste of time and there is no point to watch it\n",
      "Reconstr: this movie is about it with a lot of you can see the movie <end>\n",
      "\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Create CheckpointSaverHook.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp5q9zil3i/model.ckpt-1564\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "INFO:tensorflow:Saving checkpoints for 1565 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp5q9zil3i/model.ckpt.\n",
      "INFO:tensorflow:loss = 57.186344, step = 1565\n",
      "INFO:tensorflow:nll_loss = 55.104565, kl_w = 0.08368247, kl_loss = 24.877115\n",
      "INFO:tensorflow:global_step/sec: 2.53972\n",
      "INFO:tensorflow:loss = 51.874313, step = 1665 (39.376 sec)\n",
      "INFO:tensorflow:nll_loss = 49.499, kl_w = 0.09738124, kl_loss = 24.391888 (39.375 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.5262\n",
      "INFO:tensorflow:loss = 49.297832, step = 1765 (39.585 sec)\n",
      "INFO:tensorflow:nll_loss = 46.62631, kl_w = 0.11304584, kl_loss = 23.632229 (39.585 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.52404\n",
      "INFO:tensorflow:loss = 49.489513, step = 1865 (39.619 sec)\n",
      "INFO:tensorflow:nll_loss = 46.454746, kl_w = 0.13086486, kl_loss = 23.190086 (39.619 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.58111\n",
      "INFO:tensorflow:loss = 49.638405, step = 1965 (38.743 sec)\n",
      "INFO:tensorflow:nll_loss = 46.187103, kl_w = 0.15101443, kl_loss = 22.854118 (38.743 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.66389\n",
      "INFO:tensorflow:loss = 53.515694, step = 2065 (37.539 sec)\n",
      "INFO:tensorflow:nll_loss = 49.86238, kl_w = 0.17364664, kl_loss = 21.038769 (37.539 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.6408\n",
      "INFO:tensorflow:loss = 54.809265, step = 2165 (37.867 sec)\n",
      "INFO:tensorflow:nll_loss = 50.960667, kl_w = 0.19887616, kl_loss = 19.351736 (37.867 sec)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:global_step/sec: 2.73232\n",
      "INFO:tensorflow:loss = 54.790527, step = 2265 (36.599 sec)\n",
      "INFO:tensorflow:nll_loss = 50.64678, kl_w = 0.22676536, kl_loss = 18.273281 (36.599 sec)\n",
      "INFO:tensorflow:Saving checkpoints for 2346 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp5q9zil3i/model.ckpt.\n",
      "INFO:tensorflow:Loss for final step: 54.296196.\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp5q9zil3i/model.ckpt-2346\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "\n",
      "Original: i love this film and i think it is one of the best films\n",
      "Reconstr: i give it a chance to watch it because the movie a great movie <end>\n",
      "\n",
      "Original: this movie is a waste of time and there is no point to watch it\n",
      "Reconstr: this movie is one of those who is worth seeing it to be one <end>\n",
      "\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Create CheckpointSaverHook.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp5q9zil3i/model.ckpt-2346\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "INFO:tensorflow:Saving checkpoints for 2347 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp5q9zil3i/model.ckpt.\n",
      "INFO:tensorflow:loss = 52.760197, step = 2347\n",
      "INFO:tensorflow:nll_loss = 48.295868, kl_w = 0.25161827, kl_loss = 17.742466\n",
      "INFO:tensorflow:global_step/sec: 2.52972\n",
      "INFO:tensorflow:loss = 50.47516, step = 2447 (39.531 sec)\n",
      "INFO:tensorflow:nll_loss = 45.3369, kl_w = 0.2842792, kl_loss = 18.074701 (39.531 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.83522\n",
      "INFO:tensorflow:loss = 53.61573, step = 2547 (35.271 sec)\n",
      "INFO:tensorflow:nll_loss = 48.386337, kl_w = 0.31937042, kl_loss = 16.374065 (35.270 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.78594\n",
      "INFO:tensorflow:loss = 53.734016, step = 2647 (35.895 sec)\n",
      "INFO:tensorflow:nll_loss = 47.99819, kl_w = 0.35663486, kl_loss = 16.083185 (35.895 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.65408\n",
      "INFO:tensorflow:loss = 52.913494, step = 2747 (37.678 sec)\n",
      "INFO:tensorflow:nll_loss = 46.756668, kl_w = 0.39571938, kl_loss = 15.558564 (37.678 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.75955\n",
      "INFO:tensorflow:loss = 51.78798, step = 2847 (36.238 sec)\n",
      "INFO:tensorflow:nll_loss = 45.35582, kl_w = 0.43618327, kl_loss = 14.746458 (36.238 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.75676\n",
      "INFO:tensorflow:loss = 55.95737, step = 2947 (36.275 sec)\n",
      "INFO:tensorflow:nll_loss = 49.290356, kl_w = 0.47751516, kl_loss = 13.961891 (36.275 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.5739\n",
      "INFO:tensorflow:loss = 57.011475, step = 3047 (38.851 sec)\n",
      "INFO:tensorflow:nll_loss = 50.37846, kl_w = 0.5191573, kl_loss = 12.776503 (38.851 sec)\n",
      "INFO:tensorflow:Saving checkpoints for 3128 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp5q9zil3i/model.ckpt.\n",
      "INFO:tensorflow:Loss for final step: 56.897583.\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp5q9zil3i/model.ckpt-3128\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "\n",
      "Original: i love this film and i think it is one of the best films\n",
      "Reconstr: i would have to watch it on the movie i gave it a 1 <end>\n",
      "\n",
      "Original: this movie is a waste of time and there is no point to watch it\n",
      "Reconstr: this one is a must see for those who is no time for it <end>\n",
      "\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Create CheckpointSaverHook.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp5q9zil3i/model.ckpt-3128\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "INFO:tensorflow:Saving checkpoints for 3129 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp5q9zil3i/model.ckpt.\n",
      "INFO:tensorflow:loss = 51.762924, step = 3129\n",
      "INFO:tensorflow:nll_loss = 44.92703, kl_w = 0.55313194, kl_loss = 12.358528\n",
      "INFO:tensorflow:global_step/sec: 2.71682\n",
      "INFO:tensorflow:loss = 54.260815, step = 3229 (36.810 sec)\n",
      "INFO:tensorflow:nll_loss = 47.331028, kl_w = 0.5938731, kl_loss = 11.668799 (36.810 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.81256\n",
      "INFO:tensorflow:loss = 55.697197, step = 3329 (35.554 sec)\n",
      "INFO:tensorflow:nll_loss = 48.61201, kl_w = 0.6333619, kl_loss = 11.186632 (35.553 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.90732\n",
      "INFO:tensorflow:loss = 55.126804, step = 3429 (34.396 sec)\n",
      "INFO:tensorflow:nll_loss = 48.03051, kl_w = 0.6711373, kl_loss = 10.573535 (34.396 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.90276\n",
      "INFO:tensorflow:loss = 52.210747, step = 3529 (34.450 sec)\n",
      "INFO:tensorflow:nll_loss = 45.026436, kl_w = 0.7068222, kl_loss = 10.164243 (34.450 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.90162\n",
      "INFO:tensorflow:loss = 52.77495, step = 3629 (34.463 sec)\n",
      "INFO:tensorflow:nll_loss = 45.694237, kl_w = 0.7401343, kl_loss = 9.566792 (34.463 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.85165\n",
      "INFO:tensorflow:loss = 53.27659, step = 3729 (35.067 sec)\n",
      "INFO:tensorflow:nll_loss = 46.36541, kl_w = 0.7708882, kl_loss = 8.965216 (35.067 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.89205\n",
      "INFO:tensorflow:loss = 57.235584, step = 3829 (34.578 sec)\n",
      "INFO:tensorflow:nll_loss = 50.861748, kl_w = 0.79899096, kl_loss = 7.977356 (34.578 sec)\n",
      "INFO:tensorflow:Saving checkpoints for 3910 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp5q9zil3i/model.ckpt.\n",
      "INFO:tensorflow:Loss for final step: 55.6716.\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp5q9zil3i/model.ckpt-3910\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "\n",
      "Original: i love this film and i think it is one of the best films\n",
      "Reconstr: i guess i would give it a 7 10 10 i was a 10 <end>\n",
      "\n",
      "Original: this movie is a waste of time and there is no point to watch it\n",
      "Reconstr: this is a movie that is worth your time and not worth watching it <end>\n",
      "\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Create CheckpointSaverHook.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp5q9zil3i/model.ckpt-3910\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "INFO:tensorflow:Saving checkpoints for 3911 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp5q9zil3i/model.ckpt.\n",
      "INFO:tensorflow:loss = 56.748856, step = 3911\n",
      "INFO:tensorflow:nll_loss = 49.84851, kl_w = 0.82004714, kl_loss = 8.414572\n",
      "INFO:tensorflow:global_step/sec: 2.81243\n",
      "INFO:tensorflow:loss = 52.065613, step = 4011 (35.558 sec)\n",
      "INFO:tensorflow:nll_loss = 45.291687, kl_w = 0.84334546, kl_loss = 8.0322075 (35.557 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.8621\n",
      "INFO:tensorflow:loss = 54.033203, step = 4111 (34.939 sec)\n",
      "INFO:tensorflow:nll_loss = 48.046844, kl_w = 0.8641271, kl_loss = 6.927638 (34.940 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.90148\n",
      "INFO:tensorflow:loss = 55.242, step = 4211 (34.465 sec)\n",
      "INFO:tensorflow:nll_loss = 49.40223, kl_w = 0.88253593, kl_loss = 6.6170354 (34.466 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.82148\n",
      "INFO:tensorflow:loss = 57.11499, step = 4311 (35.442 sec)\n",
      "INFO:tensorflow:nll_loss = 51.134407, kl_w = 0.89874285, kl_loss = 6.6543865 (35.442 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.76692\n",
      "INFO:tensorflow:loss = 58.921696, step = 4411 (36.141 sec)\n",
      "INFO:tensorflow:nll_loss = 53.445824, kl_w = 0.9129343, kl_loss = 5.9981017 (36.141 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.74626\n",
      "INFO:tensorflow:loss = 55.100544, step = 4511 (36.413 sec)\n",
      "INFO:tensorflow:nll_loss = 49.27398, kl_w = 0.92530197, kl_loss = 6.296933 (36.413 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.86371\n",
      "INFO:tensorflow:loss = 53.560604, step = 4611 (34.920 sec)\n",
      "INFO:tensorflow:nll_loss = 48.093174, kl_w = 0.93603605, kl_loss = 5.8410454 (34.920 sec)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Saving checkpoints for 4692 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp5q9zil3i/model.ckpt.\n",
      "INFO:tensorflow:Loss for final step: 53.27417.\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp5q9zil3i/model.ckpt-4692\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "\n",
      "Original: i love this film and i think it is one of the best films\n",
      "Reconstr: i can't believe me i am a fan of this movie i have a <end>\n",
      "\n",
      "Original: this movie is a waste of time and there is no point to watch it\n",
      "Reconstr: i am a huge fan of the film and i am a huge fan <end>\n",
      "\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Create CheckpointSaverHook.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp5q9zil3i/model.ckpt-4692\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "INFO:tensorflow:Saving checkpoints for 4693 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp5q9zil3i/model.ckpt.\n",
      "INFO:tensorflow:loss = 57.35493, step = 4693\n",
      "INFO:tensorflow:nll_loss = 51.43254, kl_w = 0.94374704, kl_loss = 6.2754006\n",
      "INFO:tensorflow:global_step/sec: 2.64514\n",
      "INFO:tensorflow:loss = 54.67176, step = 4793 (37.806 sec)\n",
      "INFO:tensorflow:nll_loss = 49.468597, kl_w = 0.95196813, kl_loss = 5.46569 (37.806 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.64063\n",
      "INFO:tensorflow:loss = 55.025543, step = 4893 (37.870 sec)\n",
      "INFO:tensorflow:nll_loss = 49.672604, kl_w = 0.9590399, kl_loss = 5.581562 (37.870 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.81862\n",
      "INFO:tensorflow:loss = 52.16002, step = 4993 (35.478 sec)\n",
      "INFO:tensorflow:nll_loss = 47.08113, kl_w = 0.9651086, kl_loss = 5.2625027 (35.478 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.75623\n",
      "INFO:tensorflow:loss = 49.820942, step = 5093 (36.281 sec)\n",
      "INFO:tensorflow:nll_loss = 44.714054, kl_w = 0.97030604, kl_loss = 5.263172 (36.281 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.57\n",
      "INFO:tensorflow:loss = 52.983593, step = 5193 (38.910 sec)\n",
      "INFO:tensorflow:nll_loss = 47.976612, kl_w = 0.97474945, kl_loss = 5.1366844 (38.910 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.79816\n",
      "INFO:tensorflow:loss = 52.158203, step = 5293 (35.738 sec)\n",
      "INFO:tensorflow:nll_loss = 47.137726, kl_w = 0.9785427, kl_loss = 5.1305647 (35.738 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.79358\n",
      "INFO:tensorflow:loss = 54.938114, step = 5393 (35.796 sec)\n",
      "INFO:tensorflow:nll_loss = 50.48974, kl_w = 0.9817768, kl_loss = 4.530945 (35.796 sec)\n",
      "INFO:tensorflow:Saving checkpoints for 5474 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp5q9zil3i/model.ckpt.\n",
      "INFO:tensorflow:Loss for final step: 54.896923.\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp5q9zil3i/model.ckpt-5474\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "\n",
      "Original: i love this film and i think it is one of the best films\n",
      "Reconstr: i would recommend this movie to anyone who is one of the best movies <end>\n",
      "\n",
      "Original: this movie is a waste of time and there is no point to watch it\n",
      "Reconstr: the film is a waste of time and money and not really worth seeing <end>\n",
      "\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Create CheckpointSaverHook.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp5q9zil3i/model.ckpt-5474\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "INFO:tensorflow:Saving checkpoints for 5475 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp5q9zil3i/model.ckpt.\n",
      "INFO:tensorflow:loss = 53.68703, step = 5475\n",
      "INFO:tensorflow:nll_loss = 48.731552, kl_w = 0.98406756, kl_loss = 5.0357084\n",
      "INFO:tensorflow:global_step/sec: 2.71121\n",
      "INFO:tensorflow:loss = 49.05899, step = 5575 (36.885 sec)\n",
      "INFO:tensorflow:nll_loss = 44.22655, kl_w = 0.9864804, kl_loss = 4.8986664 (36.885 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.57059\n",
      "INFO:tensorflow:loss = 49.98869, step = 5675 (38.902 sec)\n",
      "INFO:tensorflow:nll_loss = 45.79241, kl_w = 0.98853207, kl_loss = 4.24496 (38.901 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.50806\n",
      "INFO:tensorflow:loss = 53.126938, step = 5775 (39.871 sec)\n",
      "INFO:tensorflow:nll_loss = 49.066444, kl_w = 0.9902755, kl_loss = 4.1003666 (39.871 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.64293\n",
      "INFO:tensorflow:loss = 54.40632, step = 5875 (37.837 sec)\n",
      "INFO:tensorflow:nll_loss = 50.162205, kl_w = 0.9917561, kl_loss = 4.279393 (37.838 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.55102\n",
      "INFO:tensorflow:loss = 54.316414, step = 5975 (39.200 sec)\n",
      "INFO:tensorflow:nll_loss = 50.099586, kl_w = 0.99301285, kl_loss = 4.246498 (39.199 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.73476\n",
      "INFO:tensorflow:loss = 51.17121, step = 6075 (36.566 sec)\n",
      "INFO:tensorflow:nll_loss = 47.014244, kl_w = 0.9940791, kl_loss = 4.1817265 (36.566 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.56428\n",
      "INFO:tensorflow:loss = 51.05745, step = 6175 (38.997 sec)\n",
      "INFO:tensorflow:nll_loss = 46.78045, kl_w = 0.99498355, kl_loss = 4.2985654 (38.997 sec)\n",
      "INFO:tensorflow:Saving checkpoints for 6256 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp5q9zil3i/model.ckpt.\n",
      "INFO:tensorflow:Loss for final step: 46.84692.\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp5q9zil3i/model.ckpt-6256\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "\n",
      "Original: i love this film and i think it is one of the best films\n",
      "Reconstr: i have to admit i was a huge fan of the movie the worst <end>\n",
      "\n",
      "Original: this movie is a waste of time and there is no point to watch it\n",
      "Reconstr: this is a very good movie about a couple of hours and i was <end>\n",
      "\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Create CheckpointSaverHook.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp5q9zil3i/model.ckpt-6256\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "INFO:tensorflow:Saving checkpoints for 6257 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp5q9zil3i/model.ckpt.\n",
      "INFO:tensorflow:loss = 46.562756, step = 6257\n",
      "INFO:tensorflow:nll_loss = 42.23388, kl_w = 0.9956215, kl_loss = 4.347913\n",
      "INFO:tensorflow:global_step/sec: 2.37532\n",
      "INFO:tensorflow:loss = 46.756992, step = 6357 (42.101 sec)\n",
      "INFO:tensorflow:nll_loss = 42.44667, kl_w = 0.9962913, kl_loss = 4.326367 (42.101 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.73744\n",
      "INFO:tensorflow:loss = 49.9253, step = 6457 (36.530 sec)\n",
      "INFO:tensorflow:nll_loss = 45.74327, kl_w = 0.99685884, kl_loss = 4.1952076 (36.530 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.91067\n",
      "INFO:tensorflow:loss = 52.90568, step = 6557 (34.356 sec)\n",
      "INFO:tensorflow:nll_loss = 48.724194, kl_w = 0.9973398, kl_loss = 4.19264 (34.356 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.67925\n",
      "INFO:tensorflow:loss = 51.473373, step = 6657 (37.324 sec)\n",
      "INFO:tensorflow:nll_loss = 47.653976, kl_w = 0.99774724, kl_loss = 3.82802 (37.324 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.63528\n",
      "INFO:tensorflow:loss = 53.12049, step = 6757 (37.946 sec)\n",
      "INFO:tensorflow:nll_loss = 49.45507, kl_w = 0.99809235, kl_loss = 3.6724257 (37.947 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.74336\n",
      "INFO:tensorflow:loss = 53.40587, step = 6857 (36.452 sec)\n",
      "INFO:tensorflow:nll_loss = 49.752956, kl_w = 0.99838483, kl_loss = 3.6588216 (36.452 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.45859\n",
      "INFO:tensorflow:loss = 52.050613, step = 6957 (40.674 sec)\n",
      "INFO:tensorflow:nll_loss = 48.23701, kl_w = 0.9986324, kl_loss = 3.8188272 (40.674 sec)\n",
      "INFO:tensorflow:Saving checkpoints for 7038 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp5q9zil3i/model.ckpt.\n",
      "INFO:tensorflow:Loss for final step: 48.661674.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp5q9zil3i/model.ckpt-7038\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "\n",
      "Original: i love this film and i think it is one of the best films\n",
      "Reconstr: i saw this movie last night and the last night the night the night <end>\n",
      "\n",
      "Original: this movie is a waste of time and there is no point to watch it\n",
      "Reconstr: this movie is a waste of time and money and i don't think this <end>\n",
      "\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Create CheckpointSaverHook.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp5q9zil3i/model.ckpt-7038\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "INFO:tensorflow:Saving checkpoints for 7039 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp5q9zil3i/model.ckpt.\n",
      "INFO:tensorflow:loss = 54.646355, step = 7039\n",
      "INFO:tensorflow:nll_loss = 50.722137, kl_w = 0.99880695, kl_loss = 3.928906\n",
      "INFO:tensorflow:global_step/sec: 2.82371\n",
      "INFO:tensorflow:loss = 49.95899, step = 7139 (35.416 sec)\n",
      "INFO:tensorflow:nll_loss = 46.347992, kl_w = 0.9989899, kl_loss = 3.6146488 (35.416 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.91101\n",
      "INFO:tensorflow:loss = 46.17417, step = 7239 (34.352 sec)\n",
      "INFO:tensorflow:nll_loss = 42.397076, kl_w = 0.9991448, kl_loss = 3.7803278 (34.353 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.91122\n",
      "INFO:tensorflow:loss = 48.266136, step = 7339 (34.350 sec)\n",
      "INFO:tensorflow:nll_loss = 44.68577, kl_w = 0.999276, kl_loss = 3.5829618 (34.349 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.90287\n",
      "INFO:tensorflow:loss = 47.722324, step = 7439 (34.449 sec)\n",
      "INFO:tensorflow:nll_loss = 43.864983, kl_w = 0.999387, kl_loss = 3.8597066 (34.449 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.90696\n",
      "INFO:tensorflow:loss = 48.406315, step = 7539 (34.400 sec)\n",
      "INFO:tensorflow:nll_loss = 44.36548, kl_w = 0.99948114, kl_loss = 4.042934 (34.401 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.91007\n",
      "INFO:tensorflow:loss = 49.7955, step = 7639 (34.363 sec)\n",
      "INFO:tensorflow:nll_loss = 46.321495, kl_w = 0.9995608, kl_loss = 3.475533 (34.363 sec)\n",
      "INFO:tensorflow:global_step/sec: 2.89954\n",
      "INFO:tensorflow:loss = 51.267563, step = 7739 (34.488 sec)\n",
      "INFO:tensorflow:nll_loss = 48.019234, kl_w = 0.9996282, kl_loss = 3.2495384 (34.489 sec)\n",
      "INFO:tensorflow:Saving checkpoints for 7820 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp5q9zil3i/model.ckpt.\n",
      "INFO:tensorflow:Loss for final step: 54.30408.\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp5q9zil3i/model.ckpt-7820\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "\n",
      "Original: i love this film and i think it is one of the best films\n",
      "Reconstr: i would have to say that it was a good movie and my wife <end>\n",
      "\n",
      "Original: this movie is a waste of time and there is no point to watch it\n",
      "Reconstr: i would not recommend this movie to anyone who is one of the worst <end>\n",
      "\n"
     ]
    }
   ],
   "source": [
    "def inf_inp(test_strs):\n",
    "    x = [[PARAMS['word2idx'].get(w, 2) for w in s.split()] for s in test_strs]\n",
    "    x = tf.keras.preprocessing.sequence.pad_sequences(\n",
    "        x, PARAMS['max_len'], truncating='post', padding='post')\n",
    "    return x\n",
    "\n",
    "def demo(test_strs, pred_ids):\n",
    "    for s, pred in zip(test_strs, pred_ids):\n",
    "        print('\\nOriginal:', s)\n",
    "        print('Reconstr:', ' '.join([PARAMS['idx2word'].get(idx, '<unk>') for idx in pred]))\n",
    "\n",
    "\n",
    "test_strs = ['i love this film and i think it is one of the best films',\n",
    "             'this movie is a waste of time and there is no point to watch it']\n",
    "\n",
    "estimator = tf.estimator.Estimator(model_fn)\n",
    "\n",
    "for _ in range(PARAMS['n_epochs']):\n",
    "    estimator.train(tf.estimator.inputs.numpy_input_fn(\n",
    "        x = enc_inp,\n",
    "        y = {'dec_inp': dec_inp, 'dec_out': dec_out},\n",
    "        batch_size = PARAMS['batch_size'],\n",
    "        shuffle = True))\n",
    "    \n",
    "    pred_ids = list(estimator.predict(tf.estimator.inputs.numpy_input_fn(\n",
    "        x = inf_inp(test_strs),\n",
    "        shuffle = False)))\n",
    "\n",
    "    demo(test_strs, pred_ids)\n",
    "    \n",
    "    print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
